scatter axis + gather axis primitives #1813
Conversation
Very nice, very clean. I left only one minor comment on the copy on the CPU side.
I am wondering whether it makes sense for some of these to simply output non-contiguous arrays. Most ops in MLX output contiguous arrays, i.e. we greedily take the hit as early as possible (with the exception of unary/binary etc.). In this case we could always treat one of the two arrays (src or idx) as contiguous and adjust the output order accordingly.
Anyway, I guess it is a case rare enough not to matter, so the above can be categorized as a rant.
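To make the layout idea concrete, here is a standalone toy sketch (all names are made up, this is not MLX code): give the output the same stride order as idx, so the inner gather loop can walk idx linearly without a defensive copy.

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Lay the output out densely, but in the same axis order as idx's memory
// layout, instead of forcing a row-contiguous result.
std::vector<int64_t> strides_matching_order(
    const std::vector<int>& shape,
    const std::vector<int64_t>& idx_strides) {
  // Sort axes from largest to smallest idx stride: idx's memory order.
  std::vector<size_t> order(shape.size());
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(), [&](size_t a, size_t b) {
    return idx_strides[a] > idx_strides[b];
  });
  // Assign dense strides innermost-first in that order.
  std::vector<int64_t> out_strides(shape.size());
  int64_t acc = 1;
  for (auto it = order.rbegin(); it != order.rend(); ++it) {
    out_strides[*it] = acc;
    acc *= shape[*it];
  }
  return out_strides;
}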
mlx/backend/common/indexing.cpp (Outdated)

auto& updates = inputs[2];

// Copy src into out (copy allocates memory for out)
copy(src, out, CopyType::General);
Unless I am missing something, this needs to change to something that figures out the copy type. The same goes for the normal Scatter. On the GPU side we have that already, but I think it makes sense to move it to common/copy.h. It would also enable donation of src, which would be nice.
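A rough sketch of what such a helper could look like (the name and exact flag checks here are assumptions, not the actual MLX code):

#include "mlx/array.h"
#include "mlx/backend/common/copy.h"

using namespace mlx::core;

// Hypothetical helper for common/copy.h: pick the cheapest CopyType from
// the source's layout instead of always paying for CopyType::General.
CopyType get_copy_type(const array& src) {
  if (src.flags().contiguous && src.size() == src.data_size()) {
    // Dense buffer: a flat element-for-element copy suffices.
    return CopyType::Vector;
  }
  // Strided or broadcast input: fall back to the general strided copy.
  return CopyType::General;
}

The eval above would then become copy(src, out, get_copy_type(src)).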
Good point, nice low-hanging fruit!
@@ -35,6 +35,8 @@ make_jit_source(ternary_ops)
 make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
 make_jit_source(scatter kernels/indexing.h)
 make_jit_source(gather kernels/indexing.h)
+make_jit_source(gather_axis)
+make_jit_source(scatter_axis)
❤️
kernel_name += upd.flags().row_contiguous ? "c" : "nc";
kernel_name += idx.flags().row_contiguous ? "c" : "nc";

auto lib = d.get_library(lib_name, [&]() {
100% not against it, but are we moving towards having things inline unless they are used in multiple places, in which case we move them to kernels.h? Or is it only for things that are always jitted?
Indeed, right now the pattern is inline if it's always jitted and in kernels.h otherwise (since jit_kernels.cpp only gets included when the JIT is enabled).
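For readers outside the codebase, the convention reduces to something like this standalone toy (illustrative only, not the actual Device::get_library):

#include <functional>
#include <string>
#include <unordered_map>

// The library cache takes the kernel source as a lazy builder, so a
// source string that is always jitted can live inline at its single call
// site; only shared sources need a home in kernels.h.
class LibraryCache {
 public:
  const std::string& get_library(
      const std::string& name,
      const std::function<std::string()>& build_source) {
    auto it = cache_.find(name);
    if (it == cache_.end()) {
      // Built (and, in the real backend, compiled) at most once per name.
      it = cache_.emplace(name, build_source()).first;
    }
    return it->second;
  }

 private:
  std::unordered_map<std::string, std::string> cache_;
};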
}

lhs_indices = astype(lhs_indices, uint32, s);
rhs_indices = astype(rhs_indices, uint32, s);
🙏
Add a GatherAxis and ScatterAxis primitive to support take_along_axis and put_along_axis.

ScatterAxis supports two reduce modes (none and sum). The sum mode is useful for the gradient of GatherAxis. Did not add more reduce modes to manage complexity; one can always use Scatter for the other modes, or we can consider adding them in the future.

Put the kernels in the JIT by default as they are pretty simple but have a lot of combinations.
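A quick usage sketch (assuming the C++ ops mirror the Python API added here; the exact signatures are not confirmed by this thread):

#include <iostream>
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // 2x3 input: [[0, 1, 2], [3, 4, 5]]
  array a = reshape(arange(6), {2, 3});
  // One index per row along axis 1.
  array idx({0, 2}, {2, 1});

  // Gathers a[0, 0] and a[1, 2] -> [[0], [5]].
  array taken = take_along_axis(a, idx, 1);

  // Scatters -1 into those positions of a copy of a.
  array updated = put_along_axis(a, idx, array(-1), 1);

  std::cout << taken << std::endl;
  std::cout << updated << std::endl;
  return 0;
}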
Incidentally closes #1807
TODO: